package colite

import (
	"compress/gzip"
	"io"
	"io/ioutil"
	"math/rand"
	"net/http"
	"regexp"
	"strings"
	"sync"
	"time"
)

// httpBackend bundles the HTTP client used to send requests with the
// per-domain limit rules that throttle them.
type httpBackend struct {
	// LimitRules holds the registered request-limiting rules. Limit appends
	// to it and GetMatchingRule scans it in order.
	LimitRules []*LimitRule

	// Client is the HTTP client Do uses to send requests.
	Client *http.Client

	// lock is taken by Limit (write) and GetMatchingRule (read) around
	// accesses to LimitRules. It is allocated by Init.
	lock *sync.RWMutex
}

// checkHeadersFunc is a response-header hook. Do calls it with the status
// code and headers once the response has arrived but before the body is
// read; returning false makes Do stop and return ErrAbortedAfterHeaders.
type checkHeadersFunc func(statusCode int, header http.Header) bool

// Init prepares the backend for use: it allocates the lock that guards
// LimitRules, seeds the global math/rand source from the current time (Do
// draws random delays from it), and installs a fresh HTTP client with a
// 10-second timeout.
//
// jar becomes the client's cookie jar. It should be non-nil, otherwise
// cookies are not stored between requests.
func (h *httpBackend) Init(jar http.CookieJar) {
	h.lock = new(sync.RWMutex)

	rand.Seed(time.Now().UnixNano())

	client := &http.Client{Timeout: 10 * time.Second}
	client.Jar = jar
	h.Client = client
}

// Do sends request through h.Client and returns the status code, headers and
// fully-read body of the response.
//
// Rate limiting: if a LimitRule matches request.URL.Host, Do first takes a
// slot from the rule (blocking while all slots are in use). Before returning
// it sleeps for the rule's Delay plus a random duration in [0, RandomDelay),
// and only then releases the slot.
//
// Parameters:
//   - request: the request to send. When the response carries a Request,
//     *request is overwritten with it so the caller sees the request that
//     actually produced the response.
//   - bodySize: maximum number of body bytes to keep; values <= 0 mean no
//     limit. The limit counts decoded bytes, so a gzip body is truncated the
//     same way as a plain one instead of failing on a cut-off stream.
//   - checkHeadersFunc: called with the status code and headers before the
//     body is read. Returning false aborts with ErrAbortedAfterHeaders. A nil
//     hook accepts every response.
//
// The body is gunzipped here only when the transport has not already done so
// (resp.Uncompressed is false) and either Content-Encoding mentions gzip or,
// with no Content-Encoding at all, Content-Type mentions gzip.
func (h *httpBackend) Do(request *http.Request, bodySize int, checkHeadersFunc checkHeadersFunc) (*Response, error) {
	// Throttle according to the matching limit rule, if any.
	if rule := h.GetMatchingRule(request.URL.Host); rule != nil {
		rule.waitChan <- true
		defer func() {
			delay := rule.Delay
			// rand.Int63n panics for n <= 0, so only randomise for a
			// strictly positive RandomDelay.
			if rule.RandomDelay > 0 {
				delay += time.Duration(rand.Int63n(int64(rule.RandomDelay)))
			}
			time.Sleep(delay)
			<-rule.waitChan
		}()
	}

	// Send the HTTP request.
	resp, err := h.Client.Do(request)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.Request != nil {
		*request = *resp.Request
	}

	// Run the header hook before touching the body.
	if checkHeadersFunc != nil && !checkHeadersFunc(resp.StatusCode, resp.Header) {
		return nil, ErrAbortedAfterHeaders
	}

	// Decide whether the body still needs to be gunzipped.
	contentEncoding := strings.ToLower(resp.Header.Get("Content-Encoding"))
	gzipped := strings.Contains(contentEncoding, "gzip")
	if !gzipped && contentEncoding == "" {
		gzipped = strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "gzip")
	}

	var bodyReader io.Reader = resp.Body
	if !resp.Uncompressed && gzipped {
		gz, err := gzip.NewReader(bodyReader)
		if err != nil {
			return nil, err
		}
		defer gz.Close()
		bodyReader = gz
	}
	// Limit after decoding: limiting the compressed stream would hand gzip a
	// truncated input and would not bound the decoded size.
	if bodySize > 0 {
		bodyReader = io.LimitReader(bodyReader, int64(bodySize))
	}

	body, err := ioutil.ReadAll(bodyReader)
	if err != nil {
		return nil, err
	}

	return &Response{
		StatusCode: resp.StatusCode,
		Body:       body,
		Headers:    &resp.Header,
	}, nil
}

// GetMatchingRule returns the first registered limit rule whose pattern
// matches domain, or nil when none does.
//
// Rules are tried in the order they appear in LimitRules. The slice is only
// read while holding the read lock, because Limit appends to it under the
// write lock; checking it before locking would race with that append.
//
// A backend whose lock has not been allocated yet (Init not called) reports
// no match.
func (h *httpBackend) GetMatchingRule(domain string) *LimitRule {
	if h.lock == nil {
		return nil
	}
	h.lock.RLock()
	defer h.lock.RUnlock()
	// Ranging over a nil or empty slice is a no-op, so no separate emptiness
	// check is needed.
	for _, rule := range h.LimitRules {
		if rule != nil && rule.Match(domain) {
			return rule
		}
	}
	return nil
}

// Limit initialises rule and, if that succeeds, registers it on the backend.
//
// The rule is initialised before it is appended to LimitRules so that
// GetMatchingRule never sees a half-initialised rule, and a rule whose Init
// fails is not left in the list. The error from rule.Init is returned
// unchanged.
func (h *httpBackend) Limit(rule *LimitRule) error {
	if err := rule.Init(); err != nil {
		return err
	}

	h.lock.Lock()
	defer h.lock.Unlock()
	if h.LimitRules == nil {
		h.LimitRules = make([]*LimitRule, 0, 8)
	}
	h.LimitRules = append(h.LimitRules, rule)
	return nil
}

// Limits registers every rule in rules, in order, by passing each one to
// Limit. It stops at the first rule Limit rejects and returns that error;
// rules before it stay registered and rules after it are not attempted.
// It returns nil when every rule was accepted.
func (h *httpBackend) Limits(rules []*LimitRule) error {
	for i := 0; i < len(rules); i++ {
		err := h.Limit(rules[i])
		if err != nil {
			return err
		}
	}
	return nil
}

// LimitRule describes how requests to a set of domains are throttled.
// A rule must be initialised with Init before Match or Do can use it.
type LimitRule struct {
	// DomainRegexp is the regular expression matched against the domain.
	// Init rejects an empty value with ErrNoPattern.
	DomainRegexp string

	// Delay is how long Do waits after a request before freeing its slot.
	Delay time.Duration

	// RandomDelay is the upper bound of an extra random wait that Do adds
	// on top of Delay.
	RandomDelay time.Duration

	// Parallelism is the maximum number of concurrent requests allowed for
	// matching domains. Init treats values of 1 or less as 1.
	Parallelism int

	// waitChan is a buffered channel sized by Init from Parallelism; Do
	// sends to it before a request and receives from it afterwards.
	waitChan chan bool
	// compiledRegexp is the compiled form of DomainRegexp, set by Init.
	compiledRegexp *regexp.Regexp
}

// Init prepares the rule for use.
//
// It always (re)allocates the slot channel, with capacity Parallelism, or 1
// when Parallelism is 1 or less. It then compiles DomainRegexp. It returns
// ErrNoPattern when DomainRegexp is empty, or the regexp compile error when
// the pattern is invalid; in both cases the previously compiled pattern, if
// any, is left untouched.
func (r *LimitRule) Init() error {
	// Size the slot channel.
	slots := r.Parallelism
	if slots < 1 {
		slots = 1
	}
	r.waitChan = make(chan bool, slots)

	// Compile the domain pattern.
	if r.DomainRegexp == "" {
		return ErrNoPattern
	}
	compiled, err := regexp.Compile(r.DomainRegexp)
	if err != nil {
		return err
	}
	r.compiledRegexp = compiled
	return nil
}

// Match reports whether domain matches the rule's compiled pattern.
// A rule with no compiled pattern (Init not run, or Init failed on a fresh
// rule) matches nothing.
func (r *LimitRule) Match(domain string) bool {
	return r.compiledRegexp != nil && r.compiledRegexp.MatchString(domain)
}
